import numpy as np
import torch
import scipy
from torch import nn
from torch.distributions import MultivariateNormal
from typing import Tuple
from torch.distributions import Normal

def combined_shape(length, shape=None):
    """Prepend ``length`` to ``shape`` and return the result as a tuple.

    :param length: leading dimension of the resulting shape
    :param shape: ``None`` (result is ``(length,)``), a scalar (result is
        ``(length, shape)``) or an iterable of dimensions (result is
        ``(length, *shape)``)
    :return: the combined shape tuple
    """
    if shape is None:
        return (length,)
    if np.isscalar(shape):
        return (length, shape)
    return (length, *shape)

def eval_vs(act_dim_true, index_set, extra):
    """Score a selection of indices against the true indices ``0 .. act_dim_true - 1``.

    An entry of ``index_set`` counts as a true positive when it is one of
    ``0 .. act_dim_true - 1`` and as a false positive otherwise.

    :param act_dim_true: number of true indices
    :param index_set: selected indices to evaluate
    :param extra: denominator of the false positive rate
    :return: ``[tpr, fpr, fdr]`` where ``tpr = tp / act_dim_true``,
        ``fpr = fp / extra`` and ``fdr = fp / (fp + tp)``; the last two
        denominators carry a small constant so they cannot be exactly zero
    """
    relevant = list(range(act_dim_true))
    hits = sum(1 for candidate in index_set if candidate in relevant)
    misses = len(index_set) - hits

    return [
        hits / act_dim_true,
        misses / (extra + 0.0000001),
        misses / (misses + hits + 0.000001),
    ]


def create_mask(new_index, act_dim):
    """Build an indicator tensor of size ``act_dim`` with ones at ``new_index``.

    :param new_index: index (or indices) of the entries to switch on
    :param act_dim: size passed to ``torch.zeros`` for the mask
    :return: tensor of zeros with the selected entries set to 1
    """
    indicator = torch.zeros(act_dim)
    # Switch on the selected positions; everything else stays at zero.
    indicator[new_index] = 1
    return indicator


def count_vars(module):
    """Return the total number of scalar entries over ``module.parameters()``.

    :param module: object exposing a ``parameters()`` iterable whose items
        have a ``shape`` attribute
    :return: sum over all parameters of the product of their shape
    """
    sizes = (np.prod(param.shape) for param in module.parameters())
    return sum(sizes)


def discount_cumsum(x, discount):
    """
    Compute discounted cumulative sums of a vector (trick taken from rllab).

    input:
        vector x,
        [x0,
         x1,
         x2]

    output:
        [x0 + discount * x1 + discount^2 * x2,
         x1 + discount * x2,
         x2]

    :param x: sequence to accumulate; it is reversed with ``x[::-1]`` and
        filtered along axis 0
    :param discount: discount factor applied per step
    :return: array of discounted cumulative sums, same ordering as ``x``
    """
    # A bare ``import scipy`` does not guarantee that the ``signal`` submodule
    # is loaded, so ``scipy.signal`` may raise AttributeError. Import the
    # function explicitly where it is used.
    from scipy.signal import lfilter

    # Running the first-order IIR filter y[n] = x[n] + discount * y[n-1] over
    # the reversed input yields the reversed discounted sums.
    return lfilter([1], [1, float(-discount)], x[::-1], axis=0)[::-1]



################### distributions


class LatticeStateDependentNoiseDistribution(object):
    """
    Distribution class of Lattice exploration.
    Paper: Latent Exploration for Reinforcement Learning https://arxiv.org/abs/2305.20065

    It creates correlated noise across actuators, with a covariance matrix induced by
    the network weights. Can improve exploration in high-dimensional systems.

    NOTE: this class reads ``self.mean_actions_net`` (in :meth:`distribution`, via its
    ``weight`` attribute) and calls ``self.clipped_mean_actions_net`` (in :meth:`sample`),
    but never creates them. Both must be assigned on the instance by the caller
    before those methods are used.

    :param action_dim: Dimension of the action space.
    :param latent_sde_dim: Dimension of the latent features fed to the noise model
        (number of rows of the std matrices).
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,), defaults to True
    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
        Defaults to False
    :param squash_output:  Whether to squash the output using a tanh function,
        this ensures bounds are satisfied, defaults to False.
        The flag is only stored; no method of this class reads it.
    :param learn_features: Whether to learn features for gSDE or not, defaults to False
        This will enable gradients to be backpropagated through the features, defaults to False
    :param epsilon: small value to avoid NaN due to numerical imprecision, defaults to 1e-6
    :param std_clip: clip range for the standard deviation, can be used to prevent extreme values,
        defaults to (1e-3, 1.0)
    :param std_reg: optional regularization to prevent collapsing to a deterministic policy,
        defaults to 0.0
    :param alpha: relative weight between action and latent noise, 0 removes the latent noise,
        defaults to 1 (equal weight)
    """

    def __init__(
        self,
        action_dim: int,
        latent_sde_dim: int,
        full_std: bool = True,
        use_expln: bool = False,
        squash_output: bool = False,
        learn_features: bool = False,
        epsilon: float = 1e-6,
        std_clip: Tuple[float, float] = (1e-3, 1.0),
        std_reg: float = 0.0,
        alpha: float = 1,
    ):
        self.min_std, self.max_std = std_clip
        self.std_reg = std_reg
        self.alpha = alpha
        self.latent_sde_dim = latent_sde_dim
        self.action_dim = action_dim
        self.full_std = full_std
        self.use_expln = use_expln
        self.squash_output = squash_output
        self.epsilon = epsilon
        self.learn_features = learn_features

    def get_std(self, log_std: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get the standard deviations from the learned parameter
        (log of it by default). This ensures that the std is positive.

        :param log_std: learned log-std parameter, of shape
            ``(latent_sde_dim, latent_sde_dim + action_dim)`` when ``full_std``
            is set and ``(latent_sde_dim, 2)`` otherwise
        :return: ``(corr_std, ind_std)`` with shapes
            ``(latent_sde_dim, latent_sde_dim)`` and ``(latent_sde_dim, action_dim)``
        :raises AssertionError: if ``log_std`` does not have the expected shape
        """
        log_std = log_std.clip(min=np.log(self.min_std), max=np.log(self.max_std))
        # Apply correction to remove scaling of action std as a function of the latent
        # dimension (see paper for details)
        log_std = log_std - 0.5 * np.log(self.latent_sde_dim)

        if self.use_expln:
            # From gSDE paper, it allows to keep variance
            # above zero and prevent it from growing too fast
            below_threshold = torch.exp(log_std) * (log_std <= 0)
            # Avoid NaN: zeros values that are below zero
            safe_log_std = log_std * (log_std > 0) + self.epsilon
            above_threshold = (torch.log1p(safe_log_std) + 1.0) * (log_std > 0)
            std = below_threshold + above_threshold
        else:
            # Use normal exponential
            std = torch.exp(log_std)

        if self.full_std:
            assert std.shape == (
                self.latent_sde_dim,
                self.latent_sde_dim + self.action_dim,
            ), std.shape
            # First block of columns drives the latent (correlated) noise,
            # the trailing block drives the per-action (independent) noise.
            corr_std = std[:, : self.latent_sde_dim]
            ind_std = std[:, -self.action_dim :]
        else:
            # Reduce the number of parameters: one std per latent feature for
            # each noise type, broadcast across the output columns.
            assert std.shape == (self.latent_sde_dim, 2), std.shape
            corr_std = torch.ones(self.latent_sde_dim, self.latent_sde_dim).to(log_std.device) * std[:, 0:1]
            ind_std = torch.ones(self.latent_sde_dim, self.action_dim).to(log_std.device) * std[:, 1:]
        return corr_std, ind_std

    def sample_weights(self, log_std: torch.Tensor, batch_size: int = 1) -> None:
        """
        Sample weights for the noise exploration matrices,
        using centered Gaussian distributions.

        Stores the weight distributions, one exploration matrix per noise type,
        and ``batch_size`` pre-sampled matrices per noise type on the instance.

        :param log_std: learned log-std parameter (see :meth:`get_std`)
        :param batch_size: number of exploration matrices to pre-sample
        """
        corr_std, ind_std = self.get_std(log_std)
        self.corr_weights_dist = Normal(torch.zeros_like(corr_std), corr_std)
        self.ind_weights_dist = Normal(torch.zeros_like(ind_std), ind_std)

        # Reparametrization trick to pass gradients
        self.corr_exploration_mat = self.corr_weights_dist.rsample()
        self.ind_exploration_mat = self.ind_weights_dist.rsample()

        # Pre-compute matrices in case of parallel exploration
        self.corr_exploration_matrices = self.corr_weights_dist.rsample((batch_size,))
        self.ind_exploration_matrices = self.ind_weights_dist.rsample((batch_size,))

    def distribution(
        self,
        mean_actions: torch.Tensor,
        log_std: torch.Tensor,
        latent_sde: torch.Tensor,
    ) -> MultivariateNormal:
        """
        Build the action distribution for the given mean and latent features.

        The result is stored in ``self.dis`` (used by :meth:`entropy`) and the
        latent features in ``self._latent_sde`` (used by :meth:`sample`).
        Requires ``self.mean_actions_net`` to have been assigned.

        :param mean_actions: mean of the action distribution
        :param log_std: learned log-std parameter (see :meth:`get_std`)
        :param latent_sde: latent features; a 1-D tensor is treated as a batch of one
        :return: multivariate normal whose covariance combines the correlated
            (latent) and independent (action) noise
        """
        # Detach the last layer features because we do not want to update the noise generation
        # to influence the features of the policy
        self._latent_sde = latent_sde if self.learn_features else latent_sde.detach()
        if len(self._latent_sde.shape) < 2:
            self._latent_sde = torch.unsqueeze(self._latent_sde, 0)
        corr_std, ind_std = self.get_std(log_std)
        latent_corr_variance = torch.mm(self._latent_sde**2, corr_std**2)  # Variance of the hidden state
        latent_ind_variance = torch.mm(self._latent_sde**2, ind_std**2) + self.std_reg**2  # Variance of the action

        # First consider the correlated variance: W diag(latent variance) W^T per batch element
        sigma_mat = self.alpha**2 * (self.mean_actions_net.weight * latent_corr_variance[:, None, :]).matmul(
            self.mean_actions_net.weight.T
        )
        # Then the independent one, to be added to the diagonal
        sigma_mat[:, range(self.action_dim), range(self.action_dim)] += latent_ind_variance
        self.dis = MultivariateNormal(loc=mean_actions, covariance_matrix=sigma_mat, validate_args=False)

        return self.dis

    def entropy(self) -> torch.Tensor:
        """
        Entropy of the distribution built by the last :meth:`distribution` call.

        :return: entropy of ``self.dis``
        """
        # ``self.distribution`` is the factory method, not the distribution
        # object; the built MultivariateNormal lives in ``self.dis``.
        return self.dis.entropy()

    def get_noise(
        self,
        latent_sde: torch.Tensor,
        exploration_mat: torch.Tensor,
        exploration_matrices: torch.Tensor,
    ) -> torch.Tensor:
        """
        Project the latent features through the sampled exploration weights.

        :param latent_sde: latent features, one row per batch element
        :param exploration_mat: single exploration matrix, shared by all rows
        :param exploration_matrices: one exploration matrix per batch element
        :return: noise with one row per batch element
        """
        latent_sde = latent_sde if self.learn_features else latent_sde.detach()
        # Default case: only one exploration matrix
        if len(latent_sde) == 1 or len(latent_sde) != len(exploration_matrices):
            return torch.mm(latent_sde, exploration_mat)
        # Use batch matrix multiplication for efficient computation
        # (batch_size, n_features) -> (batch_size, 1, n_features)
        latent_sde = latent_sde.unsqueeze(dim=1)
        # (batch_size, 1, n_actions)
        noise = torch.bmm(latent_sde, exploration_matrices)
        return noise.squeeze(dim=1)

    def sample(self) -> torch.Tensor:
        """
        Sample actions by perturbing both the latent features and the actions.

        Relies on state set by :meth:`distribution` (``self._latent_sde``) and
        :meth:`sample_weights` (the exploration matrices), and on
        ``self.clipped_mean_actions_net`` having been assigned.

        :return: sampled actions
        """
        latent_noise = self.alpha * self.get_noise(
            self._latent_sde, self.corr_exploration_mat, self.corr_exploration_matrices
        )
        action_noise = self.get_noise(self._latent_sde, self.ind_exploration_mat, self.ind_exploration_matrices)
        actions = self.clipped_mean_actions_net(self._latent_sde + latent_noise) + action_noise

        return actions
    

